{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8NNfwWXGvn_o",
        "output": {
          "id": 383783884102102,
          "loadingStatus": "loaded"
        }
      },
      "outputs": [],
      "source": [
        "%load_ext autoreload\n",
        "%autoreload 2"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Installation\n",
        "If you haven't installed Pearl yet, run the following cell to install it. Otherwise, you can skip it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1uLHbYlegKX-"
      },
      "outputs": [],
      "source": [
        "%pip uninstall Pearl -y\n",
        "%rm -rf Pearl\n",
        "!git clone https://github.com/facebookresearch/Pearl.git\n",
        "%cd Pearl\n",
        "%pip install .\n",
        "%cd .."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Import Modules"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vcb70ZC_h3OA"
      },
      "outputs": [],
      "source": [
        "from pearl.neural_networks.common.value_networks import EnsembleQValueNetwork\n",
        "from pearl.replay_buffers.sequential_decision_making.bootstrap_replay_buffer import BootstrapReplayBuffer\n",
        "from pearl.policy_learners.sequential_decision_making.bootstrapped_dqn import BootstrappedDQN\n",
        "from pearl.utils.functional_utils.experimentation.set_seed import set_seed\n",
        "from pearl.action_representation_modules.identity_action_representation_module import IdentityActionRepresentationModule\n",
        "from pearl.history_summarization_modules.lstm_history_summarization_module import LSTMHistorySummarizationModule\n",
        "from pearl.policy_learners.sequential_decision_making.deep_q_learning import DeepQLearning\n",
        "from pearl.replay_buffers.sequential_decision_making.fifo_off_policy_replay_buffer import FIFOOffPolicyReplayBuffer\n",
        "from pearl.utils.functional_utils.train_and_eval.online_learning import online_learning\n",
        "from pearl.pearl_agent import PearlAgent\n",
        "from pearl.tutorials.single_item_recommender_system_example.env_model import SequenceClassificationModel\n",
        "from pearl.tutorials.single_item_recommender_system_example.env import RecEnv\n",
        "import torch\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "\n",
        "set_seed(0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Load Environment\n",
        "This environment's underlying model is trained using the MIND dataset (Wu et al. 2020).\n",
        "\n",
        "Each data point:\n",
        "- A history of impressions clicked by a user\n",
        "- Each impression is represented by a 100-dim vector\n",
        "- A list of impressions and whether or not they are clicked\n",
        "\n",
        "Note that this is a contrived example meant to illustrate Pearl's usage, agent modularity, and a subset of its features; it is not meant to represent a real-world environment or problem. The environment is set up as follows:\n",
        "- State: a history of impressions by a user. (We use the history of impressions instead of clicked impressions to speed up learning in this example. Interested Pearl users can switch to the history of clicked impressions, with a much longer episode length and more samples, to run the following experiments.)\n",
        "- Dynamic action space: two randomly picked impressions\n",
        "- Action: one of the two impressions\n",
        "- Reward: click\n",
        "- Reset every 20 steps.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "g1VHtmldi3A2",
        "output": {
          "id": 1038395970722928,
          "loadingStatus": "loaded"
        }
      },
      "outputs": [],
      "source": [
        "# load environment\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "model = SequenceClassificationModel(100).to(device)\n",
        "# map_location lets a checkpoint saved on GPU load on a CPU-only machine\n",
        "model.load_state_dict(torch.load(\"Pearl/pearl/tutorials/single_item_recommender_system_example/env_model_state_dict.pt\", map_location=device))\n",
        "actions = torch.load(\"Pearl/pearl/tutorials/single_item_recommender_system_example/news_embedding_small.pt\")\n",
        "env = RecEnv(list(actions.values())[:100], model)\n",
        "observation, action_space = env.reset()\n",
        "\n",
        "# experiment code\n",
        "number_of_steps = 100000\n",
        "record_period = 400"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Vanilla DQN Agent\n",
        "This agent can handle a dynamic action space, but not partial observability or sparse rewards."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kulkpFAvnOQx"
      },
      "outputs": [],
      "source": [
        "# create a pearl agent\n",
        "\n",
        "action_representation_module = IdentityActionRepresentationModule(\n",
        "    max_number_actions=action_space.n,\n",
        "    representation_dim=action_space.action_dim,\n",
        ")\n",
        "\n",
        "# DQN-vanilla\n",
        "agent = PearlAgent(\n",
        "    policy_learner=DeepQLearning(\n",
        "        state_dim=1,\n",
        "        action_space=action_space,\n",
        "        hidden_dims=[64, 64],\n",
        "        training_rounds=50,\n",
        "        action_representation_module=action_representation_module,\n",
        "    ),\n",
        "    replay_buffer=FIFOOffPolicyReplayBuffer(100_000),\n",
        "    device_id=-1,\n",
        ")\n",
        "\n",
        "info = online_learning(\n",
        "    agent=agent,\n",
        "    env=env,\n",
        "    number_of_steps=number_of_steps,\n",
        "    print_every_x_steps=100,\n",
        "    record_period=record_period,\n",
        "    learn_after_episode=True,\n",
        ")\n",
        "torch.save(info[\"return\"], \"DQN-return.pt\")\n",
        "plt.plot(record_period * np.arange(len(info[\"return\"])), info[\"return\"], label=\"DQN\")\n",
        "plt.legend()\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## DQN Agent with LSTM History Summarization Module\n",
        "\n",
        "With history summarization, the DQN agent can now handle partially observable environments."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cDauzO74nS4c"
      },
      "outputs": [],
      "source": [
        "# Add an LSTM history summarization module\n",
        "\n",
        "agent = PearlAgent(\n",
        "    policy_learner=DeepQLearning(\n",
        "        state_dim=128,\n",
        "        action_space=action_space,\n",
        "        hidden_dims=[64, 64],\n",
        "        training_rounds=50,\n",
        "        action_representation_module=action_representation_module,\n",
        "    ),\n",
        "    history_summarization_module=LSTMHistorySummarizationModule(\n",
        "        observation_dim=1,\n",
        "        action_dim=100,\n",
        "        hidden_dim=128,\n",
        "        history_length=8,\n",
        "    ),\n",
        "    replay_buffer=FIFOOffPolicyReplayBuffer(100_000),\n",
        "    device_id=-1,\n",
        ")\n",
        "\n",
        "info = online_learning(\n",
        "    agent=agent,\n",
        "    env=env,\n",
        "    number_of_steps=number_of_steps,\n",
        "    print_every_x_steps=100,\n",
        "    record_period=record_period,\n",
        "    learn_after_episode=True,\n",
        ")\n",
        "torch.save(info[\"return\"], \"DQN-LSTM-return.pt\")\n",
        "plt.plot(record_period * np.arange(len(info[\"return\"])), info[\"return\"], label=\"DQN-LSTM\")\n",
        "plt.legend()\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Bootstrapped DQN Agent with LSTM History Summarization\n",
        "\n",
        "With Bootstrapped DQN, a value-based deep exploration algorithm, the agent can reach better performance much faster while still using history summarization."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_7Cpzoi3nVAw"
      },
      "outputs": [],
      "source": [
        "# Better exploration with BootstrappedDQN-LSTM\n",
        "\n",
        "agent = PearlAgent(\n",
        "    policy_learner=BootstrappedDQN(\n",
        "        q_ensemble_network=EnsembleQValueNetwork(\n",
        "            state_dim=128,\n",
        "            action_dim=100,\n",
        "            ensemble_size=10,\n",
        "            output_dim=1,\n",
        "            hidden_dims=[64, 64],\n",
        "            prior_scale=0.3,\n",
        "        ),\n",
        "        action_space=action_space,\n",
        "        training_rounds=50,\n",
        "        action_representation_module=action_representation_module,\n",
        "    ),\n",
        "    history_summarization_module=LSTMHistorySummarizationModule(\n",
        "        observation_dim=1,\n",
        "        action_dim=100,\n",
        "        hidden_dim=128,\n",
        "        history_length=8,\n",
        "    ),\n",
        "    replay_buffer=BootstrapReplayBuffer(100_000, 1.0, 10),\n",
        "    device_id=-1,\n",
        ")\n",
        "\n",
        "info = online_learning(\n",
        "    agent=agent,\n",
        "    env=env,\n",
        "    number_of_steps=number_of_steps,\n",
        "    print_every_x_steps=100,\n",
        "    record_period=record_period,\n",
        "    learn_after_episode=True,\n",
        ")\n",
        "torch.save(info[\"return\"], \"BootstrappedDQN-LSTM-return.pt\")\n",
        "plt.plot(record_period * np.arange(len(info[\"return\"])), info[\"return\"], label=\"BootstrappedDQN-LSTM\")\n",
        "plt.legend()\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Summary\n",
        "In this example, we illustrated how Pearl handles a dynamic action space, standard policy learning, history summarization, and intelligent exploration, all in a single agent. Running the code above should give agent performance results similar to the figure in pearl/tutorials/single_item_recommender_system_example/dqn+lstm+deep_explore.png.\n"
      ]
    }
  ],
  "metadata": {
    "custom": {
      "cells": [],
      "metadata": {
        "custom": {
          "cells": [],
          "metadata": {
            "accelerator": "GPU",
            "colab": {
              "gpuType": "T4",
              "include_colab_link": true,
              "provenance": []
            },
            "fileHeader": "",
            "fileUid": "4316417e-7688-45f2-a94f-24148bfc425e",
            "isAdHoc": false,
            "kernelspec": {
              "display_name": "pearl (local)",
              "language": "python",
              "name": "pearl_local"
            },
            "language_info": {
              "name": "python"
            }
          },
          "nbformat": 4,
          "nbformat_minor": 2
        },
        "fileHeader": "",
        "fileUid": "1158a851-91bb-437e-a391-aba92448f600",
        "indentAmount": 2,
        "isAdHoc": false,
        "language_info": {
          "name": "plaintext"
        }
      },
      "nbformat": 4,
      "nbformat_minor": 2
    },
    "indentAmount": 2
  },
  "nbformat": 4,
  "nbformat_minor": 2
}
